
# coding: utf-8

import os
import models
import torch as t
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
# import torchvision.models as models
from torchvision import transforms as T
from torch.utils import data
from PIL import Image


# Paths to one RGB/depth sample pair and the camera-metadata .mat file
# from the ActiveVision dataset scene below.
base_root = '/home/vivien/code/data/ActiveVisionDataset/'
dataset_name = 'Home_001_1'
rgb_prefix = 'jpg_rgb'
depth_prefix = 'high_res_depth'
# Frame-id templates: the {:0>5} slot is the zero-padded frame number.
jpg_sample_name = '000110{:0>5}0101.jpg'
depth_sample_name = '000110{:0>5}0103.png'

_scene_dir = os.path.join(base_root, dataset_name)
rgb_file_name = os.path.join(_scene_dir, rgb_prefix, jpg_sample_name.format(1))
depth_file_name = os.path.join(_scene_dir, depth_prefix, depth_sample_name.format(1))
mat_file_name = os.path.join(_scene_dir, 'image_structs.mat')


# Load the sample RGB/depth pair.  The depth image is converted to a numpy
# array (presumably uint16 depth values — TODO confirm against dataset docs).
rgb = Image.open(rgb_file_name)
depth = Image.open(depth_file_name)
depth_array = np.array(depth)

# NOTE(review): `models` is the local project module (the torchvision import
# of the same name is commented out at the top); `.middle` below is a
# project-specific hook, not a stock torchvision ResNet attribute — confirm.
resnet50 = models.resnet50(pretrained=True)

transforms = T.Compose([T.ToTensor()])
# NOTE(review): this rebinds `data`, shadowing `torch.utils.data` imported
# at the top of the file.
data = transforms(rgb)
batch_data = data.unsqueeze(0)  # add batch dimension: (1, C, H, W)
# Intermediate backbone feature map; the /8 coordinate mapping further down
# suggests the spatial stride at this stage is 8 — TODO confirm.
img_feature = resnet50.middle(batch_data)

# Camera metadata from the dataset's MATLAB file.
mat_data = sio.loadmat(mat_file_name)
image_structs = mat_data['image_structs']   # per-image struct array
scale = mat_data['scale'].squeeze()         # depth scale factor (scalar)
struct1 = image_structs[0,1]
# Field 2 of the struct is used as the camera matrix K — presumably the 3x3
# intrinsics; TODO confirm the field index against image_structs' dtype.
K = np.array(struct1[2])

# Project every pixel through the camera matrix K and discretise the
# resulting x/z coordinates onto an S-sized ground grid.
S = 10.5
h, w = depth_array.shape
# Per-pixel (col, row) coordinates, so loc_array[j, i] == (i, j); meshgrid's
# default 'xy' indexing reproduces the original nested list comprehension
# without the O(H*W) Python loop.
xs, ys = np.meshgrid(np.arange(w), np.arange(h))
loc_array = np.stack((xs, ys), axis=2)
depth_expand = np.expand_dims(depth_array, 2)
data = np.concatenate((loc_array, depth_expand), axis=2)
# NOTE(review): `data` has an integer dtype here, so this in-place divide
# truncates toward zero — confirm that is intended.
data[:,:,2] = data[:,:,2] / scale
data = np.expand_dims(data, 3)        # (h, w, 3, 1) homogeneous column vectors
projected_p = np.matmul(K, data)      # apply camera matrix per pixel
projected_p = projected_p / scale
projected_p = projected_p.squeeze(3)
# Map the x and z components into grid-cell coordinates.
projected_p[:,:,0] = projected_p[:,:,0]*(S-1)/2 + (S+1)/2
projected_p[:,:,2] = projected_p[:,:,2]*(S-1)/2 + (S+1)/2
projected_p[:,:,0] = np.ceil(projected_p[:,:,0])
projected_p[:,:,2] = np.ceil(projected_p[:,:,2])
# `np.int` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
# int is the documented replacement and yields the same platform dtype.
projected_p = projected_p.astype(int)


# Grid extent: maximum projected x / z cell index.  projected_p is already
# integer after the ceil+astype above, so the original
# np.ceil(...).astype(np.int).tolist() chain reduces to int(...); this also
# drops np.int, which was removed in NumPy 1.24.
o_t = int(projected_p[:,:,0].max())
o_k = int(projected_p[:,:,2].max())
o_size = max(o_t, o_k)
# One pooled feature vector per ground-grid cell; channel count comes from
# the backbone feature map.
o_feature = t.zeros((img_feature.shape[1], o_size, o_size))


# Scatter image features onto the ground grid: for each grid cell (i, j),
# find the pixels that project there, map them to feature-map coordinates
# (backbone stride 8), and channel-wise max-pool the selected vectors.
for i in range(o_size):
    for j in range(o_size):
        cell_mask = np.logical_and(projected_p[:,:,0] == i,
                                   projected_p[:,:,2] == j)
        loc_image_1, loc_image_2 = np.where(cell_mask)
        # Image -> feature-map coordinates.  np.where indices are
        # non-negative, so floor-division matches the original
        # (x/8).astype(np.int) truncation while avoiding np.int, which was
        # removed in NumPy 1.24.
        loc_feature_1 = loc_image_1 // 8
        loc_feature_2 = loc_image_2 // 8
        # Deduplicate feature-map locations hit by several pixels.
        list_loc = np.array(list(set(zip(loc_feature_1, loc_feature_2))))
        if list_loc.shape[0] == 0:
            continue
        feature_selected = img_feature[:, :, list_loc[:, 0], list_loc[:, 1]]
        # Channel-wise max over all contributing locations; the (1, C)
        # result broadcasts into the (C,) grid-cell slot.
        o_feature[:, i, j] = feature_selected.max(2)[0]

# Add a batch dimension for the ground CNN below.
o_feature = o_feature.unsqueeze(0)


import torch.nn as nn

class GroundCNN(nn.Module):
    """Two-layer convolutional head for the pooled ground-plane feature map.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU, so the
    spatial size is preserved and only the channel count changes
    (inplanes -> planes on the first stage, planes -> planes on the second).
    """

    def __init__(self, inplanes, planes):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Apply both conv -> bn stages, each followed by the shared ReLU.
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        return out


# In[75]:


# Build the ground CNN with the pooled map's channel count as input.
ground_cnn = GroundCNN(o_feature.shape[1],64)


# In[76]:


# NOTE(review): ground_cnn is still in training mode here, so its BatchNorm
# layers normalise with statistics from this single sample.
observation = ground_cnn(o_feature)


# In[79]:


# Removed leftover notebook scratch lines: a redundant torchvision import
# that shadowed the `transforms` pipeline defined above, a bare
# `transforms.RandomRotation` attribute access with no effect, and the
# undefined name `ffgogoo`, which raised NameError when the script ran.